from typing_extensions import override

from dataclasses import dataclass

from pydantic import BaseModel
from ten_ai_base.message import (
    ModuleError,
    ModuleErrorCode,
    ModuleType,
)
from ten_ai_base.struct import TTSTextInput
from ten_ai_base.tts2 import AsyncTTS2BaseExtension
from ten_runtime import (
    AsyncTenEnv,
)


class DefaultTTSConfig(BaseModel):
    """Configuration model for ``DefaultTTSExtension``.

    Instances are built from the extension's properties with
    ``DefaultTTSConfig.model_validate_json(...)``. Declare vendor-specific
    settings as regular pydantic fields on this class.
    """

    # NOTE: this is a plain pydantic model on purpose. Decorating a BaseModel
    # subclass with ``@dataclasses.dataclass`` replaces pydantic's generated
    # ``__init__``/``__repr__``/``__eq__``, so direct construction would skip
    # pydantic initialisation and produce broken instances.


class DefaultTTSExtension(AsyncTTS2BaseExtension):
    """Template TTS extension built on ``AsyncTTS2BaseExtension``.

    Subclasses must implement ``vendor``, ``request_tts`` and
    ``synthesize_audio_sample_rate``; this template raises
    ``NotImplementedError`` for all three.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Populated in on_init; stays None if the properties fail validation.
        self.config: DefaultTTSConfig | None = None

    @override
    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        """Load and validate the extension configuration.

        Reads all extension properties as JSON and validates them into
        ``DefaultTTSConfig``. Any exception raised while reading or validating
        the properties is reported through ``_handle_error`` as a fatal TTS
        error, and ``self.config`` is left as ``None``.
        """
        await super().on_init(ten_env)

        ten_env.log_info("DefaultTTSExtension on_init")

        try:
            # The property read is inside the try block so that failures there
            # are reported the same way as validation failures.
            config_json, _ = await ten_env.get_property_to_json("")
            self.config = DefaultTTSConfig.model_validate_json(config_json)
        except Exception as e:
            await self._handle_error(e)

    @override
    def vendor(self) -> str:
        """
        Get the vendor name of the TTS implementation.
        This is used for metrics and error reporting.
        """
        raise NotImplementedError(
            "This method should be implemented in subclasses."
        )

    @override
    async def request_tts(self, t: TTSTextInput) -> None:
        """
        Called when a new input item is available in the queue. Override this
        method to implement the TTS request logic.
        Use send_audio_out to send the audio data to the output when the audio
        data is ready.
        """
        raise NotImplementedError(
            "request_tts must be implemented in the subclass"
        )

    @override
    def synthesize_audio_sample_rate(self) -> int:
        """
        Return the sample rate, in Hz, of the audio this extension synthesizes.
        """
        raise NotImplementedError(
            "This method should be implemented in subclasses."
        )

    async def _handle_error(self, error: Exception) -> None:
        """Log ``error`` and report it downstream as a fatal TTS module error.

        Args:
            error: The exception to report; its ``str()`` becomes the
                ``ModuleError`` message.
        """
        # NOTE(review): relies on self.ten_env being assigned by the base
        # class (the original on_init also read it after super().on_init) —
        # confirm.
        self.ten_env.log_error(f"Default error: {error}")
        await self.send_tts_error(
            ModuleError(
                module=ModuleType.TTS,
                code=ModuleErrorCode.FATAL_ERROR.value,
                message=str(error),
            ),
        )
